import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class GroupLinearLayer(nn.Module):
    """Apply an independent linear map to each of ``num_blocks`` slices.

    Args:
        num_blocks: number of independent blocks, each with its own weights.
        din: input features per block.
        dout: output features per block.
        bias: if True, ``forward`` adds the learned per-block bias ``b``.
            ``b`` is allocated either way, so the parameter set (and any
            saved state_dict) does not depend on this flag.
    """

    def __init__(self, num_blocks, din, dout, bias=True):
        super(GroupLinearLayer, self).__init__()

        self.bias = bias
        # One (din, dout) matrix per block, plus a bias row broadcast over the batch.
        self.w = nn.Parameter(torch.empty(num_blocks, din, dout))
        self.b = nn.Parameter(torch.empty(1, num_blocks, dout))

        # Xavier-uniform bound from the per-block fan-in / fan-out.
        limit = math.sqrt(6.0) / math.sqrt(din + dout)
        nn.init.uniform_(self.w, -limit, limit)
        nn.init.zeros_(self.b)

    def forward(self, x):
        """Map x of shape (bsz, num_blocks, din) to (bsz, num_blocks, dout)."""
        # Put blocks first so bmm pairs each block with its own weight matrix.
        per_block = torch.bmm(x.transpose(0, 1), self.w)  # (num_blocks, bsz, dout)
        out = per_block.transpose(0, 1)                     # (bsz, num_blocks, dout)
        return out + self.b if self.bias else out

def _flatten_first_two(x):
    if x is None:
        return x
    return x.view([x.shape[0] * x.shape[1]] + list(x.shape[2:]))

# Input: x: (bs*k,*)
# Output: y: (bs,k,*)
def _unflatten_first(x, k):
    if x is None:
        return x
    return x.view([-1, k] + list(x.shape[1:]))

def get_positional(size):
    """Build a 4-channel positional grid of shape (4, size, size).

    Channels, indexed ``[channel, row, col]``:
        0: coordinate in [0, 1] that increases along the last (column) axis
        1: coordinate in [0, 1] that increases along the row axis
        2: ``1 - channel 0``
        3: ``1 - channel 1``

    The result is float32 and is moved to CUDA when a GPU is available.

    Args:
        size: side length of the square grid. For ``size == 1`` the single
            cell gets coordinate 0, avoiding a division by zero.
    """
    steps = np.arange(0, size)
    # Normalise to [0, 1]; a grid with fewer than two cells has no extent.
    up = steps / (size - 1) if size > 1 else np.zeros(size)
    # Mirror of ``up``. The previous code computed (size - 1 - up) / (size - 1)
    # from the already-normalised ``up``, which only spanned
    # [1 - 1/(size-1), 1] instead of [0, 1].
    down = 1.0 - up

    up = torch.from_numpy(up)
    down = torch.from_numpy(down)

    if torch.cuda.is_available():
        up = up.cuda()
        down = down.cuda()

    a = up.unsqueeze(0).repeat(size, 1)     # varies along columns
    b = up.unsqueeze(1).repeat(1, size)     # varies along rows
    c = down.unsqueeze(0).repeat(size, 1)
    d = down.unsqueeze(1).repeat(1, size)

    out = torch.stack((a, b, c, d)).float()

    return out

class Encoder(nn.Module):
    """Convolutional feature extractor returning one feature vector per pixel.

    Four 5x5 stride-1 convolutions (padding 2, so the spatial size is kept)
    map the image to ``latent_size`` channels. Four positional channels from
    ``get_positional`` are appended, the stack is layer-normalised, and each
    pixel is projected by a small MLP back to ``latent_size`` features.

    Args:
        channels: number of input image channels.
        img_size: spatial side length. The input must be square with this
            size, because the positional grid and the LayerNorm shape are
            fixed to it.
        latent_size: feature size of the conv stack and of each output vector.
    """

    def __init__(self, channels=3, img_size=128, latent_size=64):
        super().__init__()
        self.model  = nn.Sequential(nn.Conv2d(channels, 32, 5, 1, 2),
                                    nn.ReLU(),
                                    nn.Conv2d(32, 32, 5, 1, 2),
                                    nn.ReLU(),
                                    nn.Conv2d(32, 32, 5, 1, 2),
                                    nn.ReLU(),
                                    nn.Conv2d(32, latent_size, 5, 1, 2),
                                    nn.ReLU()
                                    )
        # (1, 4, img_size, img_size). This is a plain attribute, not a buffer,
        # so forward() moves it to the input's device itself.
        self.pe = get_positional(img_size).unsqueeze(0)
        self.lin = nn.Sequential(nn.Linear(latent_size + 4, latent_size), nn.ReLU(),
                                 nn.Linear(latent_size, latent_size))
        self.layer_norm = nn.LayerNorm([latent_size+4, img_size, img_size])

    def forward(self, img):
        """Encode img (bs, channels, H, W) into tokens (bs, H*W, latent_size)."""
        bs = img.size(0)
        feat = self.model(img)
        n_pixels = feat.size(2) * feat.size(3)
        # expand() broadcasts the grid over the batch without copying it.
        pe = self.pe.to(feat.device).expand(bs, -1, -1, -1)
        feat = self.layer_norm(torch.cat((feat, pe), 1))
        # (bs, C, H, W) -> (bs, H*W, C): one token per pixel.
        tokens = feat.view(bs, -1, n_pixels).transpose(1, 2)
        return self.lin(tokens)

class BasicModel(nn.Module):
    """Base class providing a training step, an evaluation step and checkpointing.

    Subclasses implement ``forward(input_img, input_qst)``, which must return
    log-probabilities of shape (batch, n_classes) because they are fed to
    ``F.nll_loss``. Subclasses must also set ``self.optimizer``.
    """

    def __init__(self, args, name):
        super(BasicModel, self).__init__()
        self.name = name

    @staticmethod
    def _accuracy(output, label):
        """Percentage of rows of ``output`` whose arg-max equals ``label``."""
        pred = output.detach().max(1)[1]
        correct = pred.eq(label).cpu().sum()
        return correct * 100. / len(label)

    def train_(self, input_img, input_qst, label):
        """Run one optimisation step on a batch; return (accuracy %, loss)."""
        # Restore training behaviour (e.g. dropout) in case test_() ran last.
        self.train()
        self.optimizer.zero_grad()
        output = self(input_img, input_qst)
        loss = F.nll_loss(output, label)
        loss.backward()
        self.optimizer.step()
        return self._accuracy(output, label), loss

    def test_(self, input_img, input_qst, label):
        """Evaluate a batch without updating weights; return (accuracy %, loss)."""
        # Switch modules such as dropout to inference behaviour.
        self.eval()
        # The loss is never backpropagated here, so skip building a graph.
        with torch.no_grad():
            output = self(input_img, input_qst)
            loss = F.nll_loss(output, label)
        return self._accuracy(output, label), loss

    def save_model(self, epoch, name):
        """Save the state_dict to ``{name}/epoch_{epoch:02d}.pth``.

        ``name`` is used as a directory; it is not created here.
        """
        torch.save(self.state_dict(), f'{name}/epoch_{epoch:02d}.pth')

class SlotAttention(nn.Module):
    """Iteratively route a set of input features into a fixed number of slots.

    Each iteration does three things:
      - slots compete for inputs (softmax over the slot axis);
      - every slot takes a weighted mean of the input values;
      - the slots are updated by a shared GRU cell plus a residual MLP.

    Args:
        num_slots: default number of slots returned by ``forward``.
        dim: feature size of the inputs and of the slots.
        iters: number of attention/update iterations.
        eps: added to the attention weights before renormalising, so the
            per-slot normalisation never divides by zero.
        hidden_dim: hidden width of the residual MLP (raised to ``dim`` if smaller).
    """

    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        # May be negative. It only scales symmetric noise in forward(), so a
        # negative entry behaves like its absolute value.
        self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim))

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )

        self.norm_input  = nn.LayerNorm(dim)
        self.norm_slots  = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

    def forward(self, inputs, num_slots = None):
        """Map inputs (b, n, dim) to slots (b, n_slots, dim).

        ``num_slots`` overrides the constructor's default slot count.
        """
        b, n, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots

        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_sigma.expand(b, n_s, -1)
        # Reparameterised sample. torch.normal(mu, sigma) rejects negative std
        # (almost certain with randn init) and gives mu/sigma zero gradient,
        # so they would never be learned.
        noise = torch.randn(b, n_s, d, device=mu.device, dtype=mu.dtype)
        slots = mu + sigma * noise

        inputs = self.norm_input(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            # Softmax over slots (dim=1) makes slots compete for each input;
            # the renormalisation over inputs then turns the update into a
            # weighted mean.
            attn = dots.softmax(dim=1) + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bjd,bij->bid', v, attn)

            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )

            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_pre_ff(slots))

        return slots

class Self_Attention(nn.Module):
    """Multi-head dot-product attention followed by a residual feed-forward block.

    Both sub-layers use post-LayerNorm residual connections:
    ``h = norm1(query + final(attn(query, key, value)))``, then
    ``norm2(h + res(h))``.

    Args:
        dim: feature size of ``query``/``key``/``value`` and of the output.
        att_dim: per-head feature size of the queries, keys and values.
        nheads: number of attention heads.
    """

    def __init__(self, dim, att_dim, nheads=4):
        super(Self_Attention, self).__init__()

        self.dim = dim
        self.att_dim = att_dim
        self.nheads = nheads

        proj_width = att_dim * nheads
        self.query_net = nn.Linear(dim, proj_width)
        self.key_net = nn.Linear(dim, proj_width)
        self.value_net = nn.Linear(dim, proj_width)

        self.final = nn.Linear(proj_width, dim)

        self.res = nn.Sequential(
            nn.Linear(dim, 2 * dim),
            nn.ReLU(),
            nn.Linear(2 * dim, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def _heads(self, t):
        """Reshape (bsz, n, nheads * att_dim) into (bsz, nheads, n, att_dim)."""
        bsz, n, _ = t.shape
        return t.reshape(bsz, n, self.nheads, self.att_dim).transpose(1, 2)

    def forward(self, query, key, value):
        """Return updated query tokens (bsz, n_read, dim).

        ``key``/``value`` have shape (bsz, n_write, dim).
        """
        bsz, n_read, _ = query.shape

        queries = self._heads(self.query_net(query)) / np.sqrt(self.att_dim)
        keys = self._heads(self.key_net(key))
        values = self._heads(self.value_net(value))

        # (bsz, nheads, n_read, n_write): each read position attends over all writes.
        weights = F.softmax(queries @ keys.transpose(-2, -1), dim=-1)
        mixed = weights @ values  # (bsz, nheads, n_read, att_dim)

        merged = mixed.transpose(1, 2).reshape(bsz, n_read, self.nheads * self.att_dim)
        attended = self.norm1(query + self.final(merged))
        return self.norm2(attended + self.res(attended))

class Compositional_Self_Attention(nn.Module):
    """Compositional attention: separate "where to attend" from "what to retrieve".

    Each of ``nheads`` heads computes its own attention pattern over the
    key/value positions. That pattern is applied to each of ``nrules`` value
    projections, which are shared across heads. A per-head softmax over the
    rules then mixes the ``nrules`` candidate outputs. The mixed result goes
    through ``final``, a residual connection + ``norm1``, and a residual
    feed-forward (``res``) + ``norm2``.

    Args:
        dim: feature size of ``query``/``key``/``value`` and of the output.
        att_dim: per-head (and per-rule) feature size.
        nheads: number of attention heads.
        nrules: number of value projections each head chooses among.
        dot: if True, score rules by a scaled dot product between a query
            projection and a projection of each rule's candidate output.
            Otherwise score them with a linear layer on their concatenation.
    """
    def __init__(self, dim, att_dim, nheads=4, nrules=1, dot=False):
        super(Compositional_Self_Attention, self).__init__()

        self.dim = dim
        self.att_dim = att_dim
        self.nheads = nheads
        self.nrules = nrules
        self.qk_dim = 16  # NOTE(review): not read anywhere in this class
        self.dot = dot

        # Per-head query/key projections; values get one projection per rule.
        self.query_net = nn.Linear(dim, att_dim * nheads)
        self.key_net = nn.Linear(dim, att_dim * nheads)
        self.value_net = nn.Linear(dim, att_dim * nrules)

        # Per-head query used to score which rule's output to keep.
        self.query_value_net = nn.Linear(dim, att_dim * nheads)

        if dot:
            self.key_value_net = nn.Linear(att_dim, att_dim)
        else:
            self.score_network = nn.Linear(2*att_dim, 1)

        self.final = nn.Linear(att_dim * nheads, dim)

        self.res = nn.Sequential(
            nn.Linear(dim,2 * dim),
            nn.ReLU(),
            nn.Linear(2 * dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, query, key, value):
        """Return updated query tokens of shape (bsz, n_read, dim).

        ``query`` is (bsz, n_read, dim) and ``key`` is (bsz, n_write, dim).
        ``value`` is reshaped using ``key``'s length, so it must match it.
        """
        bsz, n_read, _ = query.shape
        _, n_write, _ = key.shape

        res = query

        q = self.query_net(query).reshape(bsz, n_read, self.nheads, self.att_dim)
        q = q.permute(0,2,1,3) / np.sqrt(self.att_dim)  # (bsz, nheads, n_read, att_dim), pre-scaled
        k = self.key_net(key).reshape(bsz, n_write, self.nheads, self.att_dim)
        k = k.permute(0,2,3,1)  # (bsz, nheads, att_dim, n_write)
        v = self.value_net(value).reshape(bsz, n_write, self.nrules, self.att_dim)
        # (bsz, 1, nrules, n_write, att_dim): the singleton axis broadcasts over heads.
        v = v.permute(0,2,1,3).unsqueeze(1)

        # Attention weights (bsz, nheads, n_read, n_write), unsqueezed to
        # (bsz, nheads, 1, n_read, n_write) so they broadcast over rules.
        score = F.softmax(torch.matmul(q,k), dim=-1).unsqueeze(2)

        out = torch.matmul(score, v) # (bsz, nheads, nrules, n_read, att_dim)
        out = out.view(bsz, self.nheads, self.nrules, n_read, self.att_dim)

        # -> (bsz, n_read, nheads, nrules, att_dim): one candidate per (head, rule).
        out = out.permute(0, 3, 1, 2, 4).reshape(bsz, n_read, self.nheads, self.nrules, self.att_dim)

        # Score every rule per (position, head); softmax over the rule axis.
        # Both branches yield comp_score of shape (bsz, n_read, nheads, nrules, 1).
        if self.dot:
            q_v = self.query_value_net(query).reshape(bsz, n_read, self.nheads, 1, self.att_dim) / np.sqrt(self.att_dim)
            k_v = self.key_value_net(out).reshape(bsz, n_read, self.nheads, self.nrules, self.att_dim)

            comp_score = F.softmax(torch.matmul(q_v, k_v.transpose(4,3)), dim=-1).reshape(bsz, n_read, self.nheads, self.nrules, 1)
        else:
            q_v = self.query_value_net(query).reshape(bsz, n_read, self.nheads, 1, self.att_dim).expand(-1, -1, -1, self.nrules, -1)
            in_ = torch.cat((q_v, out), dim=-1)
            comp_score = F.softmax(self.score_network(in_), dim=3)

        # Weighted sum over rules, then merge the heads back into one vector.
        out = (comp_score * out).sum(dim=3).reshape(bsz, n_read, self.att_dim * self.nheads)

        out = self.final(out)

        # Post-LayerNorm residual blocks: attention, then feed-forward.
        out = self.norm1(res + out)
        out = self.norm2(out + self.res(out))

        return out

class SlotRN(BasicModel):
    """Slot-attention encoder plus attention-based reasoner for image/question pairs.

    Pipeline:
      1. ``Encoder`` (3-channel, 75x75 input) produces per-pixel features.
      2. ``SlotAttention`` groups them into 10 slots of size 48.
      3. Each slot is concatenated with a projection of the 18-dim question
         and mapped to ``transformer_dim``.
      4. The question token is prepended.
      5. ``iterations`` passes of the same attention block are applied.
      6. The question token's final state is classified into 10 classes
         (log-probabilities).

    ``args`` fields read: transformer_dim, n_heads, n_rules, att_dim,
    relation_type (stored, not used here), iterations, model
    ('Transformer' or 'Compositional'), dot, lr.

    Raises:
        ValueError: if ``args.model`` is not recognised and
            ``args.iterations > 0``.
    """

    def __init__(self, args):
        super(SlotRN, self).__init__(args, 'SlotRN')

        self.transformer_dim = args.transformer_dim
        self.heads = args.n_heads
        self.rules = args.n_rules
        self.att_dim = args.att_dim
        self.relation_type = args.relation_type
        self.iterations = args.iterations
        self.model = args.model
        self.dot = args.dot

        self.conv = Encoder(channels=3, img_size=75, latent_size=48)

        self.slot_attention = SlotAttention(num_slots=10, dim=48)
        # The question vector is 18-dimensional (fixed by this layer's input size).
        self.map_rnn = nn.Linear(18, self.transformer_dim)
        self.map_conv = nn.Linear(48 + self.transformer_dim, self.transformer_dim)

        if self.model == 'Transformer':
            self.transformer = Self_Attention(self.transformer_dim, self.att_dim, self.heads)
        elif self.model == 'Compositional':
            self.transformer = Compositional_Self_Attention(self.transformer_dim, self.att_dim, self.heads, self.rules, self.dot)
        elif self.iterations > 0:
            # Otherwise forward() would fail later with an AttributeError on
            # self.transformer. With zero iterations the block is never used.
            raise ValueError(
                f"Unknown model {self.model!r}; expected 'Transformer' or 'Compositional'"
            )

        self.final = nn.Sequential(
            nn.Linear(self.transformer_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

        # Created last so it covers every parameter registered above.
        self.optimizer = optim.Adam(self.parameters(), lr=args.lr)

    def forward(self, img, qst):
        """Return class log-probabilities (bs, 10) for images ``img`` and questions ``qst``."""
        slots = self.slot_attention(self.conv(img))           # (bs, n_slots, 48)

        q = self.map_rnn(qst).unsqueeze(1)                     # (bs, 1, transformer_dim)
        # Give every slot the question; expand() avoids materialising copies.
        q_per_slot = q.expand(-1, slots.size(1), -1)
        tokens = self.map_conv(torch.cat([slots, q_per_slot], dim=-1))
        # Prepend the question token; its final state is what gets classified.
        y = torch.cat([q, tokens], 1)

        for _ in range(self.iterations):
            y = self.transformer(y, y, y)

        return F.log_softmax(self.final(y[:, 0, :]), dim=1)